from typing import Iterable, Sequence, Set

from common.credentials import Credentials, LMHash, NTHash, Password, Username
from infection_monkey.exploit.tools import (
    generate_brute_force_credentials,
    identity_type_filter,
    secret_type_filter,
)


def generate_rdp_credentials(
    credentials: Sequence[Credentials], domains: Sequence[str], running_from_windows: bool
) -> Sequence[Credentials]:
    """
    Build the ordered, duplicate-free list of credentials to try over RDP.

    The input credentials are expanded by ``generate_brute_force_credentials`` using
    filters built for ``Username`` identities and ``Password``/``LMHash``/``NTHash``
    secrets. The resulting usernames are then qualified with domains, and duplicates
    are dropped while keeping first-seen order.

    :param credentials: Credentials collected so far
    :param domains: Domain names to prefix onto usernames
    :param running_from_windows: Whether local domains should also be applied
    :return: A list of domain-qualified credentials with duplicates removed
    """
    username_filter = identity_type_filter([Username])
    rdp_secret_filter = secret_type_filter([Password, LMHash, NTHash])
    candidates = generate_brute_force_credentials(
        credentials,
        identity_filter=username_filter,
        secret_filter=rdp_secret_filter,
    )

    qualified = list(_add_domains_to_usernames(candidates, domains, running_from_windows))
    return _remove_duplicate_credentials(qualified)


def _add_domains_to_usernames(
    credentials: Sequence[Credentials], domains: Sequence[str], running_from_windows: bool
) -> Iterable[Credentials]:
    """
    Yield credentials whose usernames are qualified with a domain.

    - Credentials without an identity are skipped.
    - A username that already contains a backslash (e.g. ``CORP\\alice``) is treated as
      already domain-qualified. It is yielded unchanged, exactly once.
    - Any other username is yielded once per domain, formatted as
      ``<domain>\\<username>``. The domains are ``domains`` followed by whatever
      ``_get_local_domains(running_from_windows)`` returns.

    If there are no domains to apply, unqualified usernames produce no credentials.

    :param credentials: Brute-force credentials to qualify
    :param domains: Domains to prefix onto unqualified usernames
    :param running_from_windows: Whether the local domains should also be applied
    :return: A lazily generated iterable of domain-qualified credentials
    """
    local_domains = _get_local_domains(running_from_windows)
    all_domains = [*domains, *local_domains]

    for credential in credentials:
        if credential.identity is None:
            continue

        username = credential.identity.username
        if "\\" in username:
            # Already "DOMAIN\user". Prefixing another domain would produce an invalid
            # "OTHER\DOMAIN\user" login, so pass it through unchanged and move on.
            yield credential
            continue

        for domain in all_domains:
            yield Credentials(
                identity=Username(username=f"{domain}\\{username}"),
                secret=credential.secret,
            )


def _get_local_domains(running_from_windows: bool) -> Set[str]:
    local_domains = set()
    if running_from_windows:
        local_domains.add(".")
        local_domains.add(_get_local_machine_domain())

    return local_domains


def _get_local_machine_domain() -> str:
    """
    Return the domain portion of the current user's SAM-compatible account name.

    ``GetUserNameEx(NameSamCompatible)`` typically returns ``DOMAIN\\user``. Everything
    before the first backslash is returned, or the whole string if it has no backslash.
    ``win32api`` is imported inside the function, presumably so that this module can
    still be imported where pywin32 is unavailable.
    """
    import win32api

    sam_compatible_name = win32api.GetUserNameEx(win32api.NameSamCompatible)
    domain, _separator, _user = sam_compatible_name.partition("\\")
    return domain


def _remove_duplicate_credentials(credentials: Sequence[Credentials]) -> Sequence[Credentials]:
    # Using a dict, not a set, to preserve order
    credentials_dict = dict.fromkeys(credentials, None)

    return list(credentials_dict.keys())
